# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""the module is used to process images."""

import copy
import random

import numpy as np
import mmcv
import cv2
from PIL import Image

from dataset.builder import build_transforms
from utils.class_factory import ClassFactory, ModuleType


@ClassFactory.register(ModuleType.PIPELINE)
class ImgRgbToBgr:
    """Swap the R and B channels of ``results['image']`` (RGB <-> BGR)."""

    def __call__(self, results):
        """Return ``results`` with channels 0 and 2 of the image exchanged."""
        rgb = results['image']
        bgr = rgb.copy()
        # exchange channel 0 and channel 2; channel 1 stays untouched
        bgr[..., 0], bgr[..., 2] = rgb[..., 2], rgb[..., 0]
        results['image'] = bgr

        return results


@ClassFactory.register(ModuleType.PIPELINE)
class RandomExpand:
    """Randomly paste the image onto a larger mean-filled canvas.

    Args:
        mean: fill color of the canvas (channel order matches the image).
        to_rgb: when True, reverse the given mean (BGR -> RGB).
        expand_ratio: probability of performing the expansion.
        ratio_range: (min, max) canvas enlargement factor.
    """

    def __init__(self, mean=(0, 0, 0), to_rgb=True, expand_ratio=0,
                 ratio_range=(1, 4)):
        self.mean = mean[::-1] if to_rgb else mean
        self.expand_ratio = expand_ratio
        self.min_ratio, self.max_ratio = ratio_range

    def __call__(self, results):
        """Expand with probability ``expand_ratio``; bboxes shift in place."""
        if random.random() >= self.expand_ratio:
            return results

        img = results.get("image")
        boxes = results.get("bboxes")

        height, width, channels = img.shape
        ratio = random.uniform(self.min_ratio, self.max_ratio)
        canvas = np.full((int(height * ratio), int(width * ratio), channels),
                         self.mean).astype(img.dtype)
        # random paste offset inside the enlarged canvas
        left = int(random.uniform(0, width * ratio - width))
        top = int(random.uniform(0, height * ratio - height))
        canvas[top:top + height, left:left + width] = img

        results['image'] = canvas
        # shift boxes by the paste offset (in-place, matching the paste)
        boxes += np.tile((left, top), 2)
        results['bboxes'] = boxes
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class Resize:
    """Resize an image to a fixed (img_width, img_height) and scale bboxes."""

    def __init__(self, img_width, img_height):
        # target output size in pixels
        self.img_width = img_width
        self.img_height = img_height

    def __call__(self, results):
        """Resize ``results['image']`` and rescale/clip ``results['bboxes']``."""
        resized, w_scale, h_scale = mmcv.imresize(
            results.get("image"), (self.img_width, self.img_height),
            return_scale=True)
        scale_factor = np.array(
            [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
        out_shape = np.asarray(
            (self.img_height, self.img_width, 1.0), dtype=np.float32)

        bboxes = results.get("bboxes") * scale_factor
        # clamp coordinates inside the resized image
        bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, out_shape[1] - 1)
        bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, out_shape[0] - 1)

        results['image'] = resized
        results['image_shape'] = out_shape
        results['bboxes'] = bboxes

        return results


def _rand(a=0., b=1.):
    return np.random.rand() * (b - a) + a


def bbox_iou(bbox_a, bbox_b, offset=0):
    """Compute pairwise Intersection-over-Union between two box sets.

    Args:
        bbox_a (numpy.ndarray): corner boxes with shape (N, 4+).
        bbox_b (numpy.ndarray): corner boxes with shape (M, 4+).
        offset (float or int): added to widths/heights so that size is
            computed as ``right - left + offset``. Must be 0 for
            normalized boxes in [0, 1].

    Returns:
        numpy.ndarray: IoU matrix of shape (N, M).

    Raises:
        IndexError: if either input has fewer than 4 columns.
    """
    if bbox_a.shape[1] < 4 or bbox_b.shape[1] < 4:
        raise IndexError("Bounding boxes axis 1 must have at least length 4")

    # pairwise intersection corners via broadcasting: (N, M, 2)
    inter_tl = np.maximum(bbox_a[:, None, :2], bbox_b[:, :2])
    inter_br = np.minimum(bbox_a[:, None, 2:4], bbox_b[:, 2:4])

    # zero the intersection area for non-overlapping pairs
    overlaps = (inter_tl < inter_br).all(axis=2)
    inter_area = np.prod(inter_br - inter_tl + offset, axis=2) * overlaps
    area_a = np.prod(bbox_a[:, 2:4] - bbox_a[:, :2] + offset, axis=1)
    area_b = np.prod(bbox_b[:, 2:4] - bbox_b[:, :2] + offset, axis=1)
    union = area_a[:, None] + area_b - inter_area
    return inter_area / union


def pil_image_reshape(interp):
    """Map an integer interp code (0-4) to the matching PIL resampling filter."""
    # code 3 would normally be AREA, but PIL offers no AREA filter,
    # so NEAREST stands in for it (see get_interp_method's note)
    mapping = {
        0: Image.NEAREST,
        1: Image.BILINEAR,
        2: Image.BICUBIC,
        3: Image.NEAREST,
        4: Image.LANCZOS,
    }
    return mapping[interp]


def get_interp_method(interp, sizes=()):
    """Resolve an interpolation code to a concrete method (0-4).

    Codes 0-4 are returned as-is:
        0 nearest, 1 bilinear, 2 bicubic, 3 nearest (area substitute),
        4 lanczos.
    Code 9 auto-selects: bicubic for pure enlargement, nearest for pure
    shrinking, bilinear for mixed scaling (bicubic when ``sizes`` is empty).
    Code 10 picks one of 0-4 at random.

    Args:
        interp (int): interpolation code (0-4, 9 or 10).
        sizes (tuple): (old_height, old_width, new_height, new_width);
            only consulted for code 9. Default: ().

    Returns:
        int, a concrete interp method in [0, 4].

    Raises:
        ValueError: if ``interp`` is not a known code.
    """
    if interp == 10:
        # pick any concrete method at random
        return random.randint(0, 4)
    if interp == 9:
        if not sizes:
            return 2
        assert len(sizes) == 4
        old_h, old_w, new_h, new_w = sizes
        if new_h > old_h and new_w > old_w:
            return 2  # enlarging both dims: bicubic
        if new_h < old_h and new_w < old_w:
            return 0  # shrinking both dims: nearest (area substitute)
        return 1  # mixed scaling: bilinear
    if interp in (0, 1, 2, 3, 4):
        return interp
    raise ValueError('Unknown interp method %d' % interp)


@ClassFactory.register(ModuleType.PIPELINE)
class ResizeWithinMultiScales:
    """
    Randomly jitter, resize and paste an image onto a fixed-size canvas,
    correcting bounding boxes accordingly (YOLO-style augmentation).

    Args:
        max_boxes : maximum number of boxes kept per image.
        jitter : aspect-ratio jitter amplitude.
        max_trial : number of random candidates tried per constraint.
        flip_ratio : probability of a horizontal flip.
        rgb : canvas fill color.
        use_constraints : whether crop candidates must satisfy IoU constraints.
    """

    def __init__(self, max_boxes, jitter, max_trial,
                 flip_ratio=0.5,
                 rgb=(128, 128, 128),
                 use_constraints=False):
        """Constructor for ResizeWithinMultiScales"""
        self.use_constraints = use_constraints
        self.max_boxes = max_boxes
        self.jitter = jitter
        self.max_trial = max_trial
        self.flip_ratio = flip_ratio
        self.rgb = rgb

    def _is_iou_satisfied_constraint(self, min_iou, max_iou, box, crop_box):
        """Return True when every box IoU with crop_box lies in [min_iou, max_iou]."""
        iou = bbox_iou(box, crop_box)
        return min_iou <= iou.min() and max_iou >= iou.max()

    def _choose_candidate_by_constraints(self, max_trial, input_w, input_h,
                                         image_w, image_h,
                                         jitter, box, use_constraints):
        """Collect (dx, dy, nw, nh) resize/paste candidates.

        Raises:
            Exception: when the annotation box array is empty.
        """
        if use_constraints:
            # (min_iou, max_iou) pairs; None means unbounded on that side
            constraints = (
                (0.1, None),
                (0.3, None),
                (0.5, None),
                (0.7, None),
                (0.9, None),
                (None, 1),
            )
        else:
            constraints = (
                (None, None),
            )
        # always keep a default candidate that uses the whole input size
        candidates = [(0, 0, input_w, input_h)]
        for min_iou, max_iou in constraints:
            min_iou = -np.inf if min_iou is None else min_iou
            max_iou = np.inf if max_iou is None else max_iou

            for _ in range(max_trial):
                # jitter aspect ratio and scale of the resized image
                new_ar = float(input_w) / float(input_h) * _rand(
                    1 - jitter, 1 + jitter) / _rand(1 - jitter, 1 + jitter)
                scale = _rand(0.25, 2)

                if new_ar < 1:
                    nh = int(scale * input_h)
                    nw = int(nh * new_ar)
                else:
                    nw = int(scale * input_w)
                    nh = int(nw / new_ar)

                dx = int(_rand(0, input_w - nw))
                dy = int(_rand(0, input_h - nh))

                if box.size > 0:
                    t_box = copy.deepcopy(box)
                    # project the boxes into the candidate's coordinates
                    t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(
                        image_w) + dx
                    t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(
                        image_h) + dy

                    crop_box = np.array((0, 0, input_w, input_h))
                    if self._is_iou_satisfied_constraint(
                            min_iou, max_iou, t_box, crop_box[np.newaxis]):
                        candidates.append((dx, dy, nw, nh))
                else:
                    raise Exception("!!! annotation box is less than 1")
        return candidates

    def _correct_bbox_by_candidates(self, candidates,
                                    input_w, input_h, image_w, image_h,
                                    flip, box, box_data, allow_outside_center):
        """Try candidates until one leaves at least one valid box.

        Returns:
            (box_data, candidate): padded box array and the chosen candidate.

        Raises:
            Exception: when no candidate leaves any valid box.
        """
        while candidates:
            if len(candidates) > 1:
                # prefer a random non-default candidate (index 0 is "no crop")
                candidate = candidates.pop(
                    np.random.randint(1, len(candidates)))
            else:
                candidate = candidates.pop(
                    np.random.randint(0, len(candidates)))
            dx, dy, nw, nh = candidate
            t_box = copy.deepcopy(box)
            t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(
                image_w) + dx
            t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(
                image_h) + dy
            if flip:
                # mirror x coordinates; the pixels are flipped in __call__
                t_box[:, [0, 2]] = input_w - t_box[:, [2, 0]]

            if not allow_outside_center:
                # drop boxes whose center falls outside the canvas
                t_box = t_box[
                    np.logical_and((t_box[:, 0] + t_box[:, 2]) / 2. >= 0.,
                                   (t_box[:, 1] + t_box[:, 3]) / 2. >= 0.)]
                t_box = t_box[
                    np.logical_and((t_box[:, 0] + t_box[:, 2]) / 2. <= input_w,
                                   (t_box[:, 1] + t_box[:, 3]) / 2. <= input_h)]

            # recorrect x, y for case x,y < 0: reset to zero, since after
            # dx and dy some boxes can become smaller than zero
            t_box[:, 0:2][t_box[:, 0:2] < 0] = 0
            # recorrect w,h not higher than input size
            t_box[:, 2][t_box[:, 2] > input_w] = input_w
            t_box[:, 3][t_box[:, 3] > input_h] = input_h
            box_w = t_box[:, 2] - t_box[:, 0]
            box_h = t_box[:, 3] - t_box[:, 1]
            # discard invalid box: w or h smaller than 1 pixel
            t_box = t_box[np.logical_and(box_w > 1, box_h > 1)]

            if t_box.shape[0] > 0:
                # NOTE(review): assumes len(t_box) <= max_boxes here; box was
                # truncated to max_boxes in __call__, so this holds for that path
                box_data[: len(t_box)] = t_box
                return box_data, candidate
        raise Exception('all candidates can not satisfied re-correct bbox')

    def __call__(self, results):
        """Apply the augmentation; updates image, bboxes and image_shape."""
        image = copy.deepcopy(results['image'])
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        box = results['bboxes']
        image_h, image_w = results['image_shape']
        input_h, input_w = results['resize_size']

        np.random.shuffle(box)
        if len(box) > self.max_boxes:
            box = box[:self.max_boxes]
        flip = _rand() < self.flip_ratio

        box_data = np.zeros((self.max_boxes, 4))

        # bug fix: honor the configured use_constraints / max_trial instead of
        # the previously hard-coded False / 10
        candidates = self._choose_candidate_by_constraints(
            use_constraints=self.use_constraints,
            max_trial=self.max_trial,
            input_w=input_w,
            input_h=input_h,
            image_w=image_w,
            image_h=image_h,
            jitter=self.jitter,
            box=box)
        box_data, candidate = self._correct_bbox_by_candidates(
            candidates=candidates,
            input_w=input_w,
            input_h=input_h,
            image_w=image_w,
            image_h=image_h,
            flip=flip,
            box=box,
            box_data=box_data,
            allow_outside_center=True)
        dx, dy, nw, nh = candidate
        interp = get_interp_method(interp=10)
        image = image.resize((nw, nh), pil_image_reshape(interp))

        # place image, gray color as background
        new_image = Image.new('RGB', (input_w, input_h), self.rgb)
        new_image.paste(image, (dx, dy))
        image = new_image
        if flip:
            # bug fix: the boxes were mirrored above but the pixels were not
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
        image = np.array(image)
        results['image'] = image
        results['bboxes'] = box_data
        results['image_shape'] = np.array(image.shape[:2], np.int32)
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class PilResize:
    """
    Resize an image with PIL.

    Args:
        resize_size : (height, width) of the output image.
        interp : interpolation code forwarded to get_interp_method.
    """

    def __init__(self, resize_size, interp=9):
        """Constructor for PilResize"""
        self.output_h, self.output_w = resize_size
        self.interp = interp

    def __call__(self, results):
        """Resize ``results['image']`` and record the new shape."""
        image = results['image']
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        src_w, src_h = image.size
        method = get_interp_method(
            interp=self.interp,
            sizes=(src_h, src_w, self.output_h, self.output_w))
        resized = image.resize((self.output_w, self.output_h),
                               pil_image_reshape(method))
        array = np.array(resized)
        results['image'] = array
        results['image_shape'] = np.array(array.shape[:2], np.int32)
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class Rescale:
    """
    Rescale an image (keeping aspect ratio) to fit inside
    (img_width, img_height), then zero-pad it to exactly that size.

    Args:
        img_width : output image width.
        img_height : output image height.
    """

    def __init__(self, img_width, img_height):
        self.img_width = img_width
        self.img_height = img_height

    def __call__(self, results):
        img = results.get("image")
        gt_bboxes = results.get("bboxes")

        # keep-ratio rescale to fit within (img_width, img_height)
        img_data, scale_factor = mmcv.imrescale(img, (
            self.img_width, self.img_height), return_scale=True)
        if img_data.shape[0] > self.img_height:
            # still too tall: rescale again against the height bound
            img_data, scale_factor2 = mmcv.imrescale(img_data, (
                self.img_height, self.img_height),
                                                     return_scale=True)
            scale_factor = scale_factor * scale_factor2

        gt_bboxes = gt_bboxes * scale_factor
        gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0,
                                     img_data.shape[1] - 1)
        gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0,
                                     img_data.shape[0] - 1)

        pad_h = self.img_height - img_data.shape[0]
        pad_w = self.img_width - img_data.shape[1]
        assert ((pad_h >= 0) and (pad_w >= 0))

        # pad with zeros on the bottom/right up to the target size
        pad_img_data = np.zeros((self.img_height, self.img_width, 3)).astype(
            img_data.dtype)
        pad_img_data[0:img_data.shape[0], 0:img_data.shape[1], :] = img_data

        img_shape = (self.img_height, self.img_width, 1.0)
        img_shape = np.asarray(img_shape, dtype=np.float32)

        # bug fix: return the padded image so it matches image_shape;
        # pad_img_data was previously computed but never used
        results['image'] = pad_img_data
        results['image_shape'] = img_shape
        results['bboxes'] = gt_bboxes

        return results


@ClassFactory.register(ModuleType.PIPELINE)
class Normalize:
    """
    Normalize an image via mmcv (subtract mean, divide by std, and
    optionally convert BGR to RGB first).

    Args:
        mean: per-channel mean.
        std: per-channel standard deviation.
        to_rgb: whether to convert BGR to RGB before normalizing.
    """
    def __init__(self, mean=(123.675, 116.28, 103.53),
                 std=(58.395, 57.12, 57.375), to_rgb=True):
        self.mean = np.array(mean)
        self.std = np.array(std)
        self.to_rgb = to_rgb

    def __call__(self, results):
        """Normalize ``results['image']`` and store it as float32."""
        normalized = mmcv.imnormalize(
            results['image'], self.mean, self.std, self.to_rgb)
        results['image'] = normalized.astype(np.float32)
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class RandomFlip:
    """
    Randomly flip the image and its bboxes horizontally.

    Args:
        ratio : flip probability.
    """
    def __init__(self, ratio):
        self.ratio = ratio

    def __call__(self, results):
        """Flip with probability ``self.ratio``; otherwise pass through."""
        if np.random.rand() >= self.ratio:
            return results

        # flip image pixels
        img_data = mmcv.imflip(results.get("image"))

        # mirror x_min / x_max around the vertical axis
        gt_bboxes = results.get("bboxes")
        flipped = copy.deepcopy(gt_bboxes)
        width = img_data.shape[1]
        flipped[..., 0::4] = width - gt_bboxes[..., 2::4] - 1
        flipped[..., 2::4] = width - gt_bboxes[..., 0::4] - 1

        results['image'] = img_data
        results['bboxes'] = flipped

        return results


@ClassFactory.register(ModuleType.PIPELINE)
class RandomPilFlip:
    """
    Use PIL to randomly flip the image.

    NOTE(review): bounding boxes are NOT adjusted here — confirm this op is
    only used in pipelines without box annotations.

    Args:
        ratio: flip probability.
        direction : Only support "FLIP_LEFT_RIGHT" and "FLIP_TOP_BOTTOM"
    """

    def __init__(self, ratio=0.5, direction="FLIP_LEFT_RIGHT"):
        """Constructor for RandomPilFlip"""
        if direction == "FLIP_LEFT_RIGHT":
            self.flip = Image.FLIP_LEFT_RIGHT
        else:
            self.flip = Image.FLIP_TOP_BOTTOM
        self.ratio = ratio

    def __call__(self, results):
        """Flip ``results['image']`` with probability ``self.ratio``."""
        if _rand() > self.ratio:
            return results
        image = results['image']
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # bug fix: Image.transpose returns a NEW image (the result was
        # previously discarded), and the configured direction is now honored
        # instead of a hard-coded FLIP_LEFT_RIGHT
        image = image.transpose(self.flip)
        results['image'] = np.array(image)
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class Transpose:
    """Permute the image axes (default HWC -> CHW)."""
    def __init__(self, perm=(2, 0, 1)):
        # target axis order handed to numpy.transpose
        self.perm = perm

    def __call__(self, results):
        """Apply the axis permutation to ``results['image']``."""
        permuted = results.get("image").transpose(self.perm).copy()
        results['image'] = permuted
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class EvalFormat:
    """Pack an (image, image_id) tuple into the pipeline's dict format."""
    def __call__(self, data_tuple):
        image, image_id = data_tuple[0], data_tuple[1]
        return {'image': image, 'image_id': image_id}


@ClassFactory.register(ModuleType.PIPELINE)
class Format:
    """Convert an (image, annotations) tuple into the pipeline dict format.

    ``annotations`` is an (N, 6) array laid out per row as
    [x_min, y_min, x_max, y_max, label, iscrowd].

    Args:
       pad_max_number : if given, pad boxes/labels/iscrowd up to this fixed
           number of rows so downstream tensor shapes stay static.
    Examples:
    """
    def __init__(self, pad_max_number=None):
        self.pad_max_number = pad_max_number

    def __call__(self, data_tuple):
        image = data_tuple[0]
        annotations = data_tuple[1]  # default box, label, iscrowd
        gt_box = annotations[:, :4]
        gt_label = annotations[:, 4]
        gt_iscrowd = annotations[:, 5]
        image_shape = np.array(image.shape[:2], np.int32)

        if self.pad_max_number is not None:
            pad = self.pad_max_number - annotations.shape[0]
            # boxes padded with zeros, labels with -1 (background) and
            # iscrowd with 1 so padded rows count as invalid below
            gt_box_new = np.pad(gt_box, ((0, pad), (0, 0)),
                                mode="constant", constant_values=0)
            gt_label_new = np.pad(gt_label, ((0, pad)),
                                  mode="constant", constant_values=-1)
            gt_iscrowd_new = np.pad(gt_iscrowd, ((0, pad)),
                                    mode="constant", constant_values=1)
            # bug fix: np.bool was removed in NumPy 1.24 — use builtin bool
            gt_iscrowd_new_revert = (~(gt_iscrowd_new.astype(bool))).astype(
                np.int32)
        else:
            gt_box_new = gt_box
            gt_label_new = gt_label
            gt_iscrowd_new_revert = (~(gt_iscrowd.astype(bool))).astype(
                np.int32)

        # valid_num marks non-crowd, non-padded rows with 1
        return {'image': image,
                'image_shape': image_shape,
                'bboxes': gt_box_new,
                'labels': gt_label_new,
                'valid_num': gt_iscrowd_new_revert}


@ClassFactory.register(ModuleType.PIPELINE)
class Collect:
    """Collect output image data: convert a results dict to an ordered tuple.

    Args:
        output_orders (list) : keys to emit, in order. Keys missing from the
            results dict are emitted as empty lists.
        output_type_dict (dict) : optional {key: numpy dtype name} casts
            applied in place before collection.

    Examples:
    """
    # bug fix: np.bool (alias of builtin bool) was removed in NumPy 1.24;
    # np.bool_ is the supported scalar type
    _np_type_dict = {'bool': np.bool_,
                     'int8': np.int8,
                     'int16': np.int16,
                     'int32': np.int32,
                     'int64': np.int64,
                     'uint8': np.uint8,
                     'uint16': np.uint16,
                     'uint32': np.uint32,
                     'uint64': np.uint64,
                     'float16': np.float16,
                     'float32': np.float32,
                     'float64': np.float64}

    def __init__(self, output_orders, output_type_dict=None):
        self.output_type_dict = output_type_dict
        self.output_orders = output_orders

    def np_type_cast(self, results):
        """Cast configured keys of ``results`` in place to numpy dtypes."""
        if self.output_type_dict is None:
            return
        for key, type_name in self.output_type_dict.items():
            if key in results:
                results[key] = results[key].astype(
                    self._np_type_dict[type_name])

    def __call__(self, results):
        self.np_type_cast(results)
        # bug fix: the original built a fallback list but then overwrote it
        # with a plain comprehension, so a missing key raised KeyError;
        # now missing keys fall back to [] as intended
        result = [results.get(k, []) for k in self.output_orders]
        return tuple(result)


@ClassFactory.register(ModuleType.PIPELINE)
class ColorDistortion:
    """
    Randomly jitter hue / saturation / value in HSV space.

    Args:
        hue : maximum absolute hue shift (fraction of the hue circle).
        saturation : maximum saturation scale factor (or its inverse).
        value : maximum value scale factor (or its inverse).
    """

    def __init__(self, hue, saturation, value):
        """Constructor for ColorDistortion"""
        self.hue = hue
        self.saturation = saturation
        self.value = value

    def __call__(self, results):
        """Apply a random HSV distortion to ``results['image']`` (RGB)."""
        image = results['image']
        hue_shift = _rand(-self.hue, self.hue)
        sat_scale = _rand(1, self.saturation) if _rand() < .5 else 1 / _rand(1, self.saturation)
        val_scale = _rand(1, self.value) if _rand() < .5 else 1 / _rand(1, self.value)

        hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV_FULL) / 255.
        hsv[..., 0] += hue_shift
        # wrap the hue channel back into [0, 1]
        hsv[..., 0][hsv[..., 0] > 1] -= 1
        hsv[..., 0][hsv[..., 0] < 0] += 1
        hsv[..., 1] *= sat_scale
        hsv[..., 2] *= val_scale
        hsv = np.clip(hsv, 0, 1) * 255.
        results['image'] = cv2.cvtColor(hsv.astype(np.uint8),
                                        cv2.COLOR_HSV2RGB_FULL)
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class ConvertGrayToColor:
    """Turn a single-channel (H, W) image into 3 channels by replication."""
    def __call__(self, results):
        """Stack the gray plane three times along a new last axis."""
        image = results['image']
        if image.ndim == 2:
            results['image'] = np.repeat(image[:, :, np.newaxis], 3, axis=-1)
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class PerBatchCocoCollect:
    """Unpack a COCO-style (image, annotation, resize_size) tuple into a dict."""
    def __call__(self, data_tuple):
        image, anno, resize_size = data_tuple[0], data_tuple[1], data_tuple[2]
        return {
            'image': image,
            'annotation': anno,
            'image_shape': np.array(image.shape[:2], np.int32),
            'bboxes': anno[:, :4],   # [x_min, y_min, x_max, y_max]
            'labels': anno[:, 4],    # class id column
            'resize_size': resize_size
        }


@ClassFactory.register(ModuleType.PIPELINE)
class YoloBboxPreprocess:
    """Build YOLO ground-truth targets from bounding boxes.

    For each of the three detection scales (strides 32 / 16 / 8) a dense
    target tensor is produced, plus a fixed-size array of the boxes
    assigned to that scale (zero-padded to ``max_boxes`` rows) so the
    downstream graph sees static shapes.

    Args:
        anchors: anchor (w, h) pairs; three anchors per detection layer.
        anchor_mask: per-layer lists of anchor indices,
            e.g. [[6, 7, 8], [3, 4, 5], [0, 1, 2]].
        num_classes: number of object classes.
        label_smooth: whether to smooth the one-hot class targets.
        label_smooth_factor: smoothing strength used when label_smooth is set.
        iou_threshold: anchors whose IoU with a box exceeds this also
            receive the box, in addition to the best-matching anchor.
        max_boxes: fixed number of gt boxes kept per scale.
    """
    def __init__(self,
                 anchors,
                 anchor_mask,
                 num_classes,
                 label_smooth,
                 label_smooth_factor,
                 iou_threshold,
                 max_boxes):
        """Constructor for YoloBboxPreprocess."""
        self.anchors = anchors
        self.num_classes = num_classes
        self.label_smooth = label_smooth
        self.label_smooth_factor = label_smooth_factor
        self.iou_threshold = iou_threshold
        self.max_boxes = max_boxes
        self.anchor_mask = anchor_mask

    def build_true_boxes(self, results):
        """Merge bboxes and labels into [x_min, y_min, x_max, y_max, label] rows."""
        bboxes = results['bboxes'].tolist()
        labels = results['labels'].tolist()
        # bboxes may be zero-padded beyond labels; pad labels with class 0
        # so every (possibly padded) box row has a label entry
        extend_labels = [0] * (len(bboxes) - len(labels))
        labels = labels + extend_labels
        true_boxes = []
        for bbox, label in zip(bboxes, labels):
            # tmp been formatted as [x_min y_min x_max y_max, label].
            tmp = []
            tmp.extend(bbox)
            tmp.append(int(label))
            true_boxes.append(tmp)
        return true_boxes

    def __call__(self, results):
        """Attach bbox1..bbox3 and gt_box1..gt_box3 targets to ``results``."""
        anchors = np.array(self.anchors)
        num_layers = anchors.shape[0] // 3
        true_boxes = self.build_true_boxes(results)
        # assumes boxes are absolute pixel corner coords — TODO confirm upstream
        true_boxes = np.array(true_boxes, dtype='float32')
        # input_shape is [h, w]
        input_shape = np.array(results['image_shape'], dtype='int32')
        # box center point (floor division keeps integral centers)
        boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2.
        # box width/height
        boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2]
        # true_boxes becomes [x, y, w, h] normalized by [w, h]
        # (input_shape reversed, since input_shape is [h, w])
        true_boxes[..., 0:2] = boxes_xy / input_shape[::-1]
        true_boxes[..., 2:4] = boxes_wh / input_shape[::-1]
        # grid_shape [h, w] for strides 32, 16, 8
        grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8]
        # y_true is indexed [gridy, gridx, anchor, 5 + classes]
        y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(
            self.anchor_mask[l]), 5 + self.num_classes), dtype='float32') for l in range(num_layers)]

        # expand dimension for iou calculating.
        anchors = np.expand_dims(anchors, 0)
        anchors_max = anchors / 2.
        anchors_min = -anchors_max
        # remove lines which is zeros (zero-padded boxes have zero width).
        valid_mask = boxes_wh[..., 0] > 0

        wh = boxes_wh[valid_mask]
        if wh.size != 0:
            wh = np.expand_dims(wh, -2)
            # wh shape : [box_num, 1, 2]
            # move to original point to compare, and choose the best layer-anchor to set
            boxes_max = wh / 2.
            boxes_min = -boxes_max

            intersect_min = np.maximum(boxes_min, anchors_min)
            intersect_max = np.minimum(boxes_max, anchors_max)
            intersect_wh = np.maximum(intersect_max - intersect_min, 0.)
            intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
            box_area = wh[..., 0] * wh[..., 1]
            anchor_area = anchors[..., 0] * anchors[..., 1]
            iou = intersect_area / (box_area + anchor_area - intersect_area)

            # find max iou of ground truth box in anchor boxes. calculate coordinate of ground truth box.
            best_anchor = np.argmax(iou, axis=-1)
            for t, n in enumerate(best_anchor):
                for l in range(num_layers):
                    if n in self.anchor_mask[l]:
                        # i is the grid x index (column): normalized x * grid width
                        i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32')
                        # j is the grid y index (row): normalized y * grid height
                        j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32')

                        k = self.anchor_mask[l].index(n)
                        c = true_boxes[t, 4].astype('int32')
                        y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4]
                        y_true[l][j, i, k, 4] = 1.

                        # label-smooth
                        if self.label_smooth:
                            sigma = self.label_smooth_factor / (self.num_classes - 1)
                            y_true[l][j, i, k, 5:] = sigma
                            y_true[l][j, i, k, 5 + c] = 1 - self.label_smooth_factor
                        else:
                            y_true[l][j, i, k, 5 + c] = 1.

            # additionally assign every anchor whose IoU passes the threshold
            threshold_anchor = (iou > self.iou_threshold)
            for t in range(threshold_anchor.shape[0]):
                for n in range(threshold_anchor.shape[1]):
                    if not threshold_anchor[t][n]:
                        continue
                    for l in range(num_layers):
                        if n not in self.anchor_mask[l]:
                            continue

                        # same grid x (column) / y (row) indices as above
                        i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32')
                        j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32')

                        k = self.anchor_mask[l].index(n)
                        c = true_boxes[t, 4].astype('int32')
                        y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4]
                        y_true[l][j, i, k, 4] = 1.

                        # label smooth
                        if self.label_smooth:
                            sigma = self.label_smooth_factor / (self.num_classes - 1)
                            y_true[l][j, i, k, 5:] = sigma
                            y_true[l][
                                j, i, k, 5 + c] = 1 - self.label_smooth_factor
                        else:
                            y_true[l][j, i, k, 5 + c] = 1.

        # pad_gt_boxes for avoiding dynamic shape
        pad_gt_box0 = np.zeros(shape=[self.max_boxes, 4], dtype=np.float32)
        pad_gt_box1 = np.zeros(shape=[self.max_boxes, 4], dtype=np.float32)
        pad_gt_box2 = np.zeros(shape=[self.max_boxes, 4], dtype=np.float32)

        mask0 = np.reshape(y_true[0][..., 4:5], [-1])
        gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4])
        # gt_box [boxes, [x,y,w,h]]
        gt_box0 = gt_box0[mask0 == 1]
        # gt_box0: get all boxes which have object
        if gt_box0.shape[0] < self.max_boxes:
            pad_gt_box0[:gt_box0.shape[0]] = gt_box0
        else:
            pad_gt_box0 = gt_box0[:self.max_boxes]
        # gt_box0.shape[0]: total number of boxes in gt_box0
        # top N of pad_gt_box0 is real box, and after are pad by zero

        mask1 = np.reshape(y_true[1][..., 4:5], [-1])
        gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4])
        gt_box1 = gt_box1[mask1 == 1]
        if gt_box1.shape[0] < self.max_boxes:
            pad_gt_box1[:gt_box1.shape[0]] = gt_box1
        else:
            pad_gt_box1 = gt_box1[:self.max_boxes]

        mask2 = np.reshape(y_true[2][..., 4:5], [-1])
        gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4])

        gt_box2 = gt_box2[mask2 == 1]
        if gt_box2.shape[0] < self.max_boxes:
            pad_gt_box2[:gt_box2.shape[0]] = gt_box2
        else:
            pad_gt_box2 = gt_box2[:self.max_boxes]
        results['bbox1'] = y_true[0]
        results['bbox2'] = y_true[1]
        results['bbox3'] = y_true[2]
        results['gt_box1'] = pad_gt_box0
        results['gt_box2'] = pad_gt_box1
        results['gt_box3'] = pad_gt_box2
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class PerBatchMap:
    """
    Per-batch preprocess map applied by the dataset's batch operation.

    Args:
        out_orders : key order for the collected output tuple.
        multi_scales : candidate sizes; one is drawn per batch.
        output_type_dict : optional dtype casts forwarded to Collect.
        pipeline : preprocess pipeline config; None disables augmentation.
    """

    def __init__(self, out_orders, multi_scales=None, output_type_dict=None,
                 pipeline=None):
        """Constructor for PerBatchMap"""
        self.out_orders = out_orders
        self.multi_scales = multi_scales
        self.output_type_dict = output_type_dict
        self.preprocess_pipeline = pipeline
        if self.preprocess_pipeline is not None:
            self.opts = build_transforms(self.preprocess_pipeline)

    def __call__(self, imgs, anno, x1, x2, x3, x4, x5, batch_info):
        args = (imgs, anno, x1, x2, x3, x4, x5, batch_info)
        if self.preprocess_pipeline is None:
            return args
        return self.data_augment(*args)

    def merge_batch(self, results, result_dict):
        """Append each per-sample value onto the batch-level lists."""
        for key, value in results.items():
            result_dict.setdefault(key, []).append(value)

    def tuple_batch(self, result_dict):
        """Collect the batched dict into an ordered tuple."""
        return Collect(self.out_orders, self.output_type_dict)(result_dict)

    def data_augment(self, *args):
        """Run every sample of the batch through the pipeline at one scale."""
        if self.preprocess_pipeline is None:
            return args

        batch_results = {}
        # a single random scale is shared by the whole batch
        resize_size = random.choice(self.multi_scales)

        for img, anno in zip(args[0], args[1]):
            sample = (img, anno, resize_size)
            for transform in self.opts:
                sample = transform(sample)
            self.merge_batch(sample, batch_results)

        return self.tuple_batch(batch_results)
